Closest Binary Search Tree Value II

Time: O(H + k); Space: O(H), where H is the height of the tree; difficulty: hard

Given a non-empty binary search tree and a target value, find k values in the BST that are closest to the target.

Notes:

  1. The given target value is a floating-point number.

  2. You may assume k is always valid, that is: k ≤ the total number of nodes.

  3. You are guaranteed that there is only one unique set of k values in the BST that are closest to the target.

Example 1:

    4
   / \
  2   5
 / \
1   3

Input: root = [4,2,5,1,3], target = 3.714286, k = 2

Output: [4,3]

Follow up:

  1. Assume that the BST is balanced: could you solve it in less than O(n) runtime (where n is the total number of nodes)?

[1]:
class TreeNode(object):
    """A binary (search) tree node: one value plus left/right child links."""

    def __init__(self, x):
        # Store the payload; children are linked later by the caller.
        self.val = x
        self.left = self.right = None
[2]:
class Solution1(object):
    def closestKValues(self, root, target, k):
        """
        Return the k values in the BST closest to target, nearest first.

        Two root-to-node path stacks act as in-order cursors that both start
        from the node closest to target.  `smaller` walks towards smaller
        values (in-order predecessors) and `larger` walks towards larger
        values (in-order successors).  Each step emits the nearer cursor and
        advances it; on a distance tie the larger value is taken.  Fewer than
        k values are returned only if the tree has fewer than k nodes.

        :type root: TreeNode
        :type target: float
        :type k: int
        :rtype: List[int]
        """
        # Nothing to report for an empty tree or a non-positive k.  Without
        # this guard min() below raises ValueError on an empty path.
        if root is None or k <= 0:
            return []

        def advance(stack, toward, away):
            """Move a path stack to the next in-order node in one direction.

            `toward(node)` returns the child on the side of travel and
            `away(node)` returns the child on the opposite side.  The stack
            holds the whole root-to-node path (current node last).  It is
            left empty when there is no next node.
            """
            if not stack:
                return
            child = toward(stack[-1])
            if child is not None:
                # Step once in the travel direction, then go as far as
                # possible the opposite way.
                stack.append(child)
                while away(stack[-1]) is not None:
                    stack.append(away(stack[-1]))
            else:
                # Climb while we are the `toward` child of our parent; the
                # first ancestor reached from the other side is the next node.
                child = stack.pop()
                while stack and child is toward(stack[-1]):
                    child = stack.pop()

        left = lambda node: node.left
        right = lambda node: node.right
        dist = lambda node: abs(node.val - target)

        # In a BST the node closest to target lies on the search path for
        # target, so record that path and cut it at the closest node.
        path = []
        node = root
        while node is not None:
            path.append(node)
            node = node.left if target < node.val else node.right
        smaller = path[:path.index(min(path, key=dist)) + 1]

        # The larger-side cursor starts at the closest node's successor so
        # that the closest node itself is emitted only once.
        larger = list(smaller)
        advance(larger, right, left)

        result = []
        for _ in range(k):
            if smaller and (not larger or dist(smaller[-1]) < dist(larger[-1])):
                result.append(smaller[-1].val)
                advance(smaller, left, right)
            elif larger:
                result.append(larger[-1].val)
                advance(larger, right, left)
            else:
                break  # Both cursors exhausted: the tree has fewer than k nodes.
        return result
[3]:
# Example 1 from the problem statement, checked against Solution1.
s = Solution1()
root = TreeNode(4)
root.left = TreeNode(2)
root.right = TreeNode(5)
root.left.left = TreeNode(1)
root.left.right = TreeNode(3)
target, k = 3.714286, 2
assert s.closestKValues(root, target, k) == [4, 3]
[4]:
class Solution2(object):
    def closestKValues(self, root, target, k):
        """
        Return the k values in the BST closest to target, nearest first.

        Two in-order iterators start at the node closest to target: one
        yields successively smaller values (predecessors) and the other
        successively larger values (successors).  Each step emits the nearer
        of the two pending nodes; on a distance tie the larger value is
        taken.  Fewer than k values are returned only if the tree has fewer
        than k nodes.

        :type root: TreeNode
        :type target: float
        :type k: int
        :rtype: List[int]
        """
        # The result is seeded with the closest node below, so k == 0 must be
        # handled here (it would otherwise return one value).  An empty tree
        # would otherwise make min() raise ValueError.
        if root is None or k <= 0:
            return []

        class BSTIterator(object):
            """In-order cursor over a BST that moves in a single direction.

            child1(node) returns the child on the side of travel and
            child2(node) returns the child on the opposite side: (left, right)
            gives descending order, (right, left) gives ascending order.
            """

            def __init__(self, stack, child1, child2):
                # stack is a root-to-node path; its last entry is the start
                # node, which becomes `cur` while its ancestors stay stacked.
                self.stack = list(stack)
                self.cur = self.stack.pop()
                self.child1 = child1
                self.child2 = child2

            def next(self):
                """Advance to and return the next node, or None when exhausted."""
                node = None
                if self.cur is not None and self.child1(self.cur) is not None:
                    # Step once in the travel direction, then as far as
                    # possible the opposite way.
                    self.stack.append(self.cur)
                    node = self.child1(self.cur)
                    while self.child2(node) is not None:
                        self.stack.append(node)
                        node = self.child2(node)
                elif self.stack:
                    # Climb until we arrive at an ancestor from its child2
                    # side; that ancestor is the next node.
                    prev = self.cur
                    node = self.stack.pop()
                    while node is not None and self.child2(node) is not prev:
                        prev = node
                        node = self.stack.pop() if self.stack else None
                self.cur = node
                return node

        # Build the stack to the closest node; in a BST it lies on the
        # search path for target.
        stack = []
        node = root
        while node is not None:
            stack.append(node)
            node = node.left if target < node.val else node.right
        # An exhausted iterator yields None, which must never be preferred.
        dist = lambda node: (abs(node.val - target) if node is not None
                             else float("inf"))
        stack = stack[:stack.index(min(stack, key=dist)) + 1]

        backward = lambda node: node.left
        forward = lambda node: node.right
        smaller_it = BSTIterator(stack, backward, forward)
        larger_it = BSTIterator(stack, forward, backward)
        smaller_node, larger_node = smaller_it.next(), larger_it.next()

        # The closest node is always first; the rest come from the iterators.
        result = [stack[-1].val]
        for _ in range(k - 1):
            if smaller_node is None and larger_node is None:
                break  # Both iterators exhausted: fewer than k nodes.
            if dist(smaller_node) < dist(larger_node):
                result.append(smaller_node.val)
                smaller_node = smaller_it.next()
            else:
                result.append(larger_node.val)
                larger_node = larger_it.next()
        return result
[5]:
# Example 1 from the problem statement, checked against Solution2.
s = Solution2()
root = TreeNode(4)
root.left = TreeNode(2)
root.right = TreeNode(5)
root.left.left = TreeNode(1)
root.left.right = TreeNode(3)
target, k = 3.714286, 2
assert s.closestKValues(root, target, k) == [4, 3]